import argparse
from data import *
from train import train_model 
from moxing.framework import file
import os
# Reference notes on the layout of the cfg dict (a no-op string expression).
# It is a raw string so the backslashes in the Windows paths of the example are
# kept literally instead of being parsed as escape sequences (`\d` / `\h` are
# invalid escapes that warn on recent Pythons; `\r` / `\t` would become
# control characters).
r'''
Layout of the cfg dict that the __main__ block below builds and passes to train_model:
param:cfg{
    model{
        pretrained_path,
        resnet_type,
        num_classes,
        save_path,
        local_root,
    }
    dataset{
        data_root,
        train_transform,
        val_transform,
        train_ratio,
    }
    hyperparam{
        batch_size,
        num_workers,
        lr,
        momentum,
        epoch,
        steps,
        lr_schedule,
        gamma,
    }
}
Example cfg for a local Windows run (not every key listed above is set here):
    cfg = {
        'model': {
            'pretrained_path': r'models\resnet101-5d3b4d8f.pth',
            'resnet_type': 'resnet101',
            'num_classes': 43,
        },
        'dataset':{
            'data_root': r"D:\develop\huawei cloud\trash_dataset\train_data",
            'train_transform': PreprocessTransform(224, (0, 0, 0)),
            'val_transform' : BaseTransform(224, (0, 0, 0)),
            'train_ratio': 0.9,
        },
        'hyperparam':{
            'batch_size': 32,
            'num_workers': 8,
            'lr': 1e-3,
            'momentum': 0.9,
            'epoch': 10,
            'steps':""
        }
    }
'''
# Build the command-line parser.
# NOTE: ArgumentDefaultsHelpFormatter only appends "(default: ...)" to arguments
# that have a help string, so every argument below is given one.
parser = argparse.ArgumentParser(description="train trash",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Paths: OBS locations and the local cache directory.
parser.add_argument('--train_url', type=str, default='obs://obs-test/ckpt/mnist',
                    help='the path model saved')
parser.add_argument('--data_url', type=str, default='obs://obs-test/data/',
                    help='the training data')
parser.add_argument('--local_root', type=str, default='/cache/',
                    help='a directory used for transfer data between local path and OBS path')
parser.add_argument('--pretrained_path', type=str, default='obs://bin-for-competition/develop/pytorch-resnet-for-huawei-modelarts/models/resnet101_pretrained.pth',
                    help='pretrained weights, copied into local_root before training')
parser.add_argument('--resnet_type', type=str, default='resnet101',
                    help='ResNet variant, e.g. resnet101')
# Training hyperparameters.
parser.add_argument("--num_classes", type=int, default=43,
                    help='number of output classes')
parser.add_argument('--train_ratio', type=float, default=0.9,
                    help='fraction of the data used for the training split')
parser.add_argument('--batch_size', type=int, default=32,
                    help='mini-batch size')
parser.add_argument('--num_workers', type=int, default=8,
                    help='number of data-loading workers')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='optimizer momentum')
parser.add_argument('--epoch', type=int, default=8,
                    help='number of training epochs')
parser.add_argument('--steps', type=str, default="8,12,14,16",
                    help='comma-separated milestones for the learning-rate schedule')
parser.add_argument('--lr_schedule', type=str, default="multisteps",
                    help='learning-rate schedule type')
parser.add_argument('--gamma', type=float, default=0.8,
                    help='decay factor used by the learning-rate schedule')

# Parse arguments. parse_known_args() returns unrecognized arguments (presumably
# ones added by the launching platform) instead of exiting on them. The
# misspelled name `unkown` is kept because the __main__ block refers to it.
args, unkown = parser.parse_known_args()

if __name__ == '__main__':

    # `unkown` holds the leftover arguments from parse_known_args(); print them
    # so unexpected options are visible in the training log.
    print(unkown)

    # Local paths under local_root used by this run.
    local_tmp_dir = os.path.join(args.local_root, 'tmp/')
    local_data_root = os.path.join(args.local_root, 'train_data/')
    local_pretrained_path = os.path.join(args.local_root, args.resnet_type + '_pretrained.pth')

    # Unlike os.mkdir, os.makedirs also creates missing parent directories, and
    # with exist_ok=True it does not fail when the directory already exists.
    os.makedirs(args.local_root, exist_ok=True)
    os.makedirs(local_tmp_dir, exist_ok=True)

    # Copy data from the cloud (OBS) to the local cache.
    print('copying dataset to cache')
    os.makedirs(local_data_root, exist_ok=True)
    file.copy_parallel(args.data_url, local_data_root)
    print('copying pretrained model')
    file.copy(args.pretrained_path, local_pretrained_path)

    # "8,12,14,16" -> [8, 12, 14, 16]. Empty entries (e.g. --steps "" or a
    # trailing comma) are skipped instead of making int('') raise ValueError.
    steps = [int(step) for step in args.steps.split(',') if step.strip()]

    # Per-channel normalization statistics shared by the train and val transforms.
    rgb_means = (138.11617731, 128.38959552, 116.94768342)
    rgb_std = (52.90101662, 54.29838, 56.22659914)

    cfg = {
        'model': {
            'pretrained_path': local_pretrained_path,
            'resnet_type': args.resnet_type,
            'num_classes': args.num_classes,
            'save_path': args.train_url,
            'local_root': args.local_root,
        },
        'dataset': {
            'data_root': local_data_root,
            'train_transform': PreprocessTransform(224, rgb_means=rgb_means, rgb_std=rgb_std),
            'val_transform': BaseTransform(224, rgb_means=rgb_means, rgb_std=rgb_std),
            'train_ratio': args.train_ratio,
        },
        'hyperparam': {
            'batch_size': args.batch_size,
            'num_workers': args.num_workers,
            'lr': args.lr,
            'momentum': args.momentum,
            'epoch': args.epoch,
            'steps': steps,
            'lr_schedule': args.lr_schedule,
            'gamma': args.gamma,
        }
    }
    print(cfg)
    train_model(cfg)
